use super::{Node, NodeCodegen, OnnxIntoNode, try_convert_onnx_node};
use crate::burn::{BurnImports, Scope, TensorType, Type};

use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

/// Generate inline code for a loop body subgraph
///
/// Converts an OnnxGraph into a TokenStream that executes in each loop iteration.
/// Returns the body code
fn generate_loop_body_code<PS: PrecisionSettings + 'static>(
    subgraph: &onnx_ir::OnnxGraph,
    scope: &mut Scope,
    node_position: usize,
) -> TokenStream {
    let mut body = quote! {};

    // Helper to extract tensor types
    fn to_tensor(ty: Type) -> Option<TensorType> {
        match ty {
            Type::Tensor(tensor) => Some(tensor),
            _ => None,
        }
    }

    // Convert ONNX nodes to Burn nodes
    let burn_nodes: Vec<_> = subgraph
        .nodes
        .iter()
        .map(|node| {
            try_convert_onnx_node::<PS>(node.clone())
                .unwrap_or_else(|| panic!("Unsupported op in loop body: {}", node.name()))
        })
        .collect();

    // Register subgraph inputs in scope (they reference loop variables)
    for input in &subgraph.inputs {
        if let Some(tensor) = to_tensor(Type::from(input)) {
            scope.tensor_register_variable(&tensor, node_position);
        }
    }

    // Build scope for subgraph nodes: register outputs and future uses
    for (idx, burn_node) in burn_nodes.iter().enumerate() {
        let subgraph_node_pos = node_position + idx + 1;

        // Register node outputs
        for output in burn_node.output_types() {
            if let Some(tensor) = to_tensor(output) {
                scope.tensor_register_variable(&tensor, subgraph_node_pos);
            }
        }

        // Register future uses of node inputs
        for input in burn_node.input_types() {
            if let Some(tensor) = to_tensor(input) {
                scope.tensor_register_future_use(&tensor, subgraph_node_pos - 1);
            }
        }
    }

    // Register future uses for subgraph outputs
    for output in &subgraph.outputs {
        if let Some(tensor) = to_tensor(Type::from(output)) {
            scope.tensor_register_future_use(&tensor, node_position + burn_nodes.len());
        }
    }

    // Generate forward code for each node
    for (idx, burn_node) in burn_nodes.iter().enumerate() {
        let node_code = burn_node.forward(scope, node_position + idx + 1);
        body.extend(node_code);
    }

    body
}

/// Loop node - iterative execution with loop-carried dependencies and scan outputs
///
/// The Loop operation executes a body subgraph for a specified number of iterations,
/// carrying state between iterations. Can also collect intermediate values as scan outputs.
///
/// Per ONNX spec:
/// - Body inputs: [iteration_num, condition, loop_carried_dependencies...]
/// - Body outputs: [condition, loop_carried_dependencies..., scan_outputs...]
/// - Loop outputs: [final_loop_carried_deps..., scan_outputs_concatenated...]
#[derive(Debug, Clone)]
pub struct LoopNode {
    /// `M` — maximum trip count (scalar i64 or rank-1 single-element tensor).
    pub max_trip_count: Type,
    /// `cond` — initial loop condition (scalar bool or rank-1 single-element tensor).
    pub condition: Type,
    /// Initial values of the loop-carried dependencies (inputs after M and cond).
    pub v_initial: Vec<Type>,
    /// Loop node outputs: final loop-carried dependencies followed by scan outputs.
    pub outputs: Vec<Type>,
    /// The subgraph executed on every iteration.
    pub body: onnx_ir::OnnxGraph,
    /// Number of loop-carried dependency outputs; `outputs.len() - this` = scan outputs.
    pub num_loop_carried_outputs: usize,
}

impl LoopNode {
    pub fn new(
        max_trip_count: Type,
        condition: Type,
        v_initial: Vec<Type>,
        outputs: Vec<Type>,
        body: onnx_ir::OnnxGraph,
        num_loop_carried_outputs: usize,
    ) -> Self {
        Self {
            max_trip_count,
            condition,
            v_initial,
            outputs,
            body,
            num_loop_carried_outputs,
        }
    }
}

impl<PS: PrecisionSettings + 'static> NodeCodegen<PS> for LoopNode {
    fn output_types(&self) -> Vec<Type> {
        self.outputs.clone()
    }

    fn input_types(&self) -> Vec<Type> {
        // Loop inputs per ONNX spec: [M (max trip count), cond, v_initial...]
        let mut inputs = vec![self.max_trip_count.clone(), self.condition.clone()];
        inputs.extend(self.v_initial.clone());
        inputs
    }

    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
        // Extract max trip count (can be scalar or rank-1 tensor)
        let max_trip_count = match &self.max_trip_count {
            Type::Scalar(scalar) => {
                let name = &scalar.name;
                quote! { #name }
            }
            Type::Tensor(tensor) if tensor.rank == 1 => {
                // Rank-1 tensor with single element - extract the scalar value
                let name = &tensor.name;
                quote! { #name.clone().into_scalar().elem::<i64>() }
            }
            _ => panic!("Loop max_trip_count must be scalar i64 or rank-1 tensor"),
        };

        // Extract initial condition (can be scalar or rank-1 tensor)
        let initial_cond = match &self.condition {
            Type::Scalar(scalar) => {
                let name = &scalar.name;
                quote! { #name }
            }
            Type::Tensor(tensor) if tensor.rank == 1 => {
                // Rank-1 tensor with single element - extract the scalar value
                let name = &tensor.name;
                quote! { #name.clone().into_scalar().elem::<bool>() }
            }
            _ => panic!("Loop condition must be scalar bool or rank-1 tensor"),
        };

        // Body inputs: [iter_num, cond_in, v_in...]
        // Body outputs: [cond_out, v_out..., scan_outputs...]
        //
        // Per ONNX spec:
        // - self.v_initial.len() = number of loop-carried deps passed as input
        // - self.num_loop_carried_outputs = number of loop-carried deps that are output/modified
        // - self.outputs.len() = num_loop_carried_outputs + num_scan_outputs

        // Number of loop-carried dependencies passed as input
        let num_loop_vars = self.v_initial.len();

        // Number of loop-carried dependencies that are actually output/modified
        let num_loop_carried_outputs = self.num_loop_carried_outputs;

        // Number of scan outputs (collected from each iteration)
        let num_scan_outputs = self.outputs.len() - num_loop_carried_outputs;

        // Body should have 2 + num_loop_vars inputs (iter, cond, v_in...)
        assert_eq!(
            self.body.inputs.len(),
            2 + num_loop_vars,
            "Loop body should have {} inputs, got {}",
            2 + num_loop_vars,
            self.body.inputs.len()
        );

        // Body should have 1 + num_loop_carried_outputs + num_scan_outputs outputs
        // [cond_out, v_out..., scan_outputs...]
        let expected_body_outputs = 1 + num_loop_carried_outputs + num_scan_outputs;
        assert_eq!(
            self.body.outputs.len(),
            expected_body_outputs,
            "Loop body should have {} outputs (1 cond + {} loop-carried + {} scan), got {}",
            expected_body_outputs,
            num_loop_carried_outputs,
            num_scan_outputs,
            self.body.outputs.len()
        );

        // Initialize loop-carried dependency variables
        let mut init_stmts = quote! {};
        let loop_var_names: Vec<_> = self
            .body
            .inputs
            .iter()
            .skip(2) // Skip iter and cond_in
            .map(|arg| syn::Ident::new(&arg.name, proc_macro2::Span::call_site()))
            .collect();

        for (idx, initial_value) in self.v_initial.iter().enumerate() {
            let var_name = &loop_var_names[idx];
            let init_value = match initial_value {
                Type::Tensor(tensor) => {
                    let tensor_name = &tensor.name;
                    quote! { #tensor_name }
                }
                Type::Scalar(scalar) => {
                    let scalar_name = &scalar.name;
                    quote! { #scalar_name }
                }
                _ => panic!("Unsupported loop-carried dependency type"),
            };

            // Only variables that are updated by the loop body outputs need to be mutable
            // Read-only variables are cloned at the start of each iteration via shadowing
            if idx < num_loop_carried_outputs {
                init_stmts.extend(quote! {
                    let mut #var_name = #init_value;
                });
            } else {
                init_stmts.extend(quote! {
                    let #var_name = #init_value;
                });
            }
        }

        // Initialize iteration counter and condition
        let iter_name = syn::Ident::new(&self.body.inputs[0].name, proc_macro2::Span::call_site());
        let cond_in_name =
            syn::Ident::new(&self.body.inputs[1].name, proc_macro2::Span::call_site());

        // Create a bool variable for the while condition (handles both scalar and rank-1 tensor)
        let cond_bool_name = syn::Ident::new(
            &format!("{}_bool", self.body.inputs[1].name),
            proc_macro2::Span::call_site(),
        );

        // For tensors, initial_cond already has .into_scalar() applied
        // For scalars, it's just the variable name
        // Either way, we can use it as-is for the bool variable
        init_stmts.extend(quote! {
            let mut #iter_name = 0i64;
            let mut #cond_bool_name = #initial_cond;
        });

        // cond_in_name holds the tensor/scalar version for passing to the loop body
        let cond_var_init = match &self.condition {
            Type::Scalar(scalar) => {
                let name = &scalar.name;
                quote! {
                    let mut #cond_in_name = #name;
                }
            }
            Type::Tensor(tensor) => {
                let name = &tensor.name;
                quote! {
                    let mut #cond_in_name = #name;
                }
            }
            _ => panic!("Unsupported condition type in Loop node"),
        };

        init_stmts.extend(cond_var_init);

        // Initialize scan output accumulators (Vec for collecting values from each iteration)
        let scan_vec_names: Vec<_> = (0..num_scan_outputs)
            .map(|i| {
                let scan_idx = 1 + num_loop_carried_outputs + i; // +1 for cond_out
                let output_name = &self.body.outputs[scan_idx].name;
                syn::Ident::new(
                    &format!("{}_scan_vec", output_name),
                    proc_macro2::Span::call_site(),
                )
            })
            .collect();

        // Initialize Vec with appropriate type for scalar vs tensor scan outputs
        for (i, vec_name) in scan_vec_names.iter().enumerate() {
            let scan_idx = 1 + num_loop_carried_outputs + i;
            let body_output_type = Type::from(&self.body.outputs[scan_idx]);

            match body_output_type {
                // Scalars accumulate as Vec of the scalar type (e.g. Vec<f32>), which
                // avoids creating N tensors in the loop; tensors accumulate as
                // Vec<Tensor> for later concatenation. The generated code is the same.
                Type::Scalar(_) | Type::Tensor(_) => {
                    init_stmts.extend(quote! {
                        let mut #vec_name = Vec::new();
                    });
                }
                _ => panic!("Unsupported scan body output type"),
            }
        }

        // For read-only variables, shadow them with a clone at the start of each iteration
        // This creates a new binding that shadows the old one for the loop body
        let mut pre_body_stmts = quote! {};
        for (idx, var_name) in loop_var_names
            .iter()
            .enumerate()
            .skip(num_loop_carried_outputs)
        {
            // Check if this is a tensor type that needs cloning
            if let Type::Tensor(_) = &self.v_initial[idx] {
                pre_body_stmts.extend(quote! {
                    let #var_name = #var_name.clone();
                });
            }
        }

        // Generate loop body code
        // Constants in the loop body are automatically registered as model fields
        // by BurnGraph::collect_all_fields() which recursively processes subgraphs
        let body_code = generate_loop_body_code::<PS>(&self.body, scope, node_position);

        // Extract condition output and loop-carried dependency outputs from body
        let cond_out_name =
            syn::Ident::new(&self.body.outputs[0].name, proc_macro2::Span::call_site());

        // Update loop variables from body outputs
        // Only the first num_loop_carried_outputs loop variables have corresponding body outputs
        let mut update_stmts = quote! {};

        // Update condition (cond_out_name is always from the loop body output, which matches the body input type)
        update_stmts.extend(quote! {
            #cond_in_name = #cond_out_name;
        });

        // Update the bool variable from the tensor/scalar
        let cond_bool_update = match &self.condition {
            Type::Scalar(_) => {
                quote! {
                    #cond_bool_name = #cond_out_name;
                }
            }
            Type::Tensor(_) => {
                quote! {
                    #cond_bool_name = #cond_out_name.clone().into_scalar().elem::<bool>();
                }
            }
            _ => panic!("Unsupported condition type in Loop node"),
        };
        update_stmts.extend(cond_bool_update);

        // Collect scan outputs from this iteration BEFORE updating loop variables
        // This ensures we capture the body output values, not the updated loop variables
        for (i, vec_name) in scan_vec_names.iter().enumerate() {
            let scan_idx = 1 + num_loop_carried_outputs + i;
            let scan_out_name = syn::Ident::new(
                &self.body.outputs[scan_idx].name,
                proc_macro2::Span::call_site(),
            );

            // Check the body output type (not the loop output type)
            let body_output_type = Type::from(&self.body.outputs[scan_idx]);

            match body_output_type {
                Type::Scalar(_) => {
                    // Scalar body outputs: just push the raw scalar value
                    // We'll create a single tensor from the Vec after the loop completes
                    update_stmts.extend(quote! {
                        #vec_name.push(#scan_out_name);
                    });
                }
                Type::Tensor(_) => {
                    // Tensor body outputs can be pushed directly
                    update_stmts.extend(quote! {
                        #vec_name.push(#scan_out_name.clone());
                    });
                }
                _ => panic!("Unsupported scan body output type"),
            }
        }

        // Update loop-carried variables from body outputs
        for (idx, var_name) in loop_var_names
            .iter()
            .enumerate()
            .take(num_loop_carried_outputs)
        {
            let out_name = syn::Ident::new(
                &self.body.outputs[idx + 1].name,
                proc_macro2::Span::call_site(),
            );
            update_stmts.extend(quote! {
                #var_name = #out_name;
            });
        }

        update_stmts.extend(quote! {
            #iter_name += 1;
        });

        // Generate output assignments with block scoping
        // All loop variables are scoped inside the block to avoid polluting the outer scope
        // For scalar conditions, suppress warnings about unused/unread cond variable
        let allow_attr = match &self.condition {
            Type::Scalar(_) => quote! { #[allow(unused_variables, unused_assignments)] },
            _ => quote! {},
        };

        // Concatenate scan outputs after loop completes (convert Vec<Tensor> to Tensor via concat)
        let mut scan_concat_stmts = quote! {};
        let scan_output_names: Vec<_> = (0..num_scan_outputs)
            .map(|i| {
                let output_idx = num_loop_carried_outputs + i;
                let output = &self.outputs[output_idx];
                match output {
                    Type::Tensor(t) => &t.name,
                    Type::Scalar(s) => &s.name, // Scalar types (rank-0 tensors in ONNX)
                    _ => panic!("Scan outputs must be tensors or scalars, got {:?}", output),
                }
            })
            .collect();

        for (i, scan_output_name) in scan_output_names.iter().enumerate() {
            let vec_name = &scan_vec_names[i];

            // Check if this scan output came from a scalar body output
            let scan_idx = 1 + num_loop_carried_outputs + i;
            let body_output_type = Type::from(&self.body.outputs[scan_idx]);

            match body_output_type {
                Type::Scalar(_) => {
                    // Scalars: create single tensor from Vec → [N], then unsqueeze at dim 1 → [N, 1]
                    // This avoids creating N tensors in the loop
                    scan_concat_stmts.extend(quote! {
                        let #scan_output_name = Tensor::<B, 1>::from_data(#vec_name.as_slice(), &B::Device::default()).unsqueeze_dim::<2>(1);
                    });
                }
                Type::Tensor(_) => {
                    // Tensors: concatenate the Vec<Tensor>
                    scan_concat_stmts.extend(quote! {
                        let #scan_output_name = Tensor::cat(#vec_name, 0);
                    });
                }
                _ => panic!("Unsupported scan body output type"),
            }
        }

        // Collect all output names (loop-carried + scan)
        let all_output_names: Vec<_> = self
            .outputs
            .iter()
            .map(|output| match output {
                Type::Tensor(t) => &t.name,
                Type::Scalar(s) => &s.name,
                _ => panic!("Unsupported output type in Loop node"),
            })
            .collect();

        // Collect loop-carried variable names (only the ones that are output)
        let loop_carried_var_names: Vec<_> = loop_var_names
            .iter()
            .take(num_loop_carried_outputs)
            .collect();

        if self.outputs.len() == 1 && num_scan_outputs == 0 {
            // Single loop-carried output: let output_name = { ... loop code ... var_name };
            let output_name = &all_output_names[0];
            let var_name = &loop_carried_var_names[0];

            quote! {
                #allow_attr
                let #output_name = {
                    #init_stmts

                    while #iter_name < #max_trip_count && #cond_bool_name {
                        #pre_body_stmts
                        #body_code
                        #update_stmts
                    }

                    #var_name
                };
            }
        } else {
            // Multiple outputs (loop-carried + scan):
            // let (out1, out2, ...) = { ... loop code + concat ... (ret1, ret2, ...) };
            //
            // Chain loop-carried and scan names into a single interpolation list so
            // the returned tuple stays syntactically valid even when one group is
            // empty. Interpolating the two groups separately with a literal comma
            // between them (`(#(#lc),*, #(#scan),*)`) would emit `(, scan1, scan2)`
            // for a Loop with zero loop-carried outputs — invalid Rust.
            let return_exprs: Vec<TokenStream> = loop_carried_var_names
                .iter()
                .map(|var| quote! { #var })
                .chain(scan_output_names.iter().map(|name| quote! { #name }))
                .collect();

            quote! {
                #allow_attr
                let (#(#all_output_names),*) = {
                    #init_stmts

                    while #iter_name < #max_trip_count && #cond_bool_name {
                        #pre_body_stmts
                        #body_code
                        #update_stmts
                    }

                    // Concatenate scan outputs
                    #scan_concat_stmts

                    // Return all outputs (loop-carried + scan)
                    (#(#return_exprs),*)
                };
            }
        }
    }

    fn into_node(self) -> Node<PS> {
        Node::Loop(self)
    }

    fn register_imports(&self, imports: &mut BurnImports) {
        // Register Vec for scan output accumulators
        let num_scan_outputs = self.outputs.len() - self.num_loop_carried_outputs;
        if num_scan_outputs > 0 {
            imports.register("alloc::vec::Vec");
        }

        // Register imports from body nodes
        for onnx_node in &self.body.nodes {
            if let Some(burn_node) = try_convert_onnx_node::<PS>(onnx_node.clone()) {
                burn_node.register_imports(imports);
            }
        }
    }
}

impl OnnxIntoNode for LoopNode {
    /// Build a [`LoopNode`] from the onnx-ir Loop node, validating that the
    /// body subgraph's output structure matches the ONNX spec.
    fn from_onnx(node: onnx_ir::Node) -> Self {
        let onnx_ir::Node::Loop(n) = node else {
            panic!("Expected Loop node");
        };

        // Per ONNX spec a Loop node always has at least the M and cond inputs.
        assert!(
            n.inputs.len() >= 2,
            "Loop node requires at least 2 inputs (M, cond), got {}",
            n.inputs.len()
        );

        // Extract M (max trip count) and cond (condition) - first two inputs
        let max_trip_count = Type::from(&n.inputs[0]);
        let condition = Type::from(&n.inputs[1]);

        // Get body graph from config
        let body = n.config.body.clone();

        // Loop-carried dependencies are inputs after M and cond
        let v_initial: Vec<Type> = n.inputs.iter().skip(2).map(Type::from).collect();

        // Per ONNX spec:
        // - Body inputs: [iteration_num, cond_in, v_in_1, ..., v_in_N]
        // - Body outputs: [cond_out, v_out_1, ..., v_out_N, scan_out_1, ..., scan_out_K]
        // - Loop node outputs: [v_final_1, ..., v_final_N, scan_1, ..., scan_K]
        //
        // Where N = number of loop-carried dependencies = v_initial.len()
        // And K = number of scan outputs = outputs.len() - N
        let num_loop_carried_outputs = v_initial.len();

        // Guard the subtraction: a malformed model with fewer node outputs than
        // loop-carried dependencies would otherwise abort with an opaque usize
        // underflow panic instead of a diagnostic.
        let num_scan_outputs = n
            .outputs
            .len()
            .checked_sub(num_loop_carried_outputs)
            .unwrap_or_else(|| {
                panic!(
                    "Loop node has fewer outputs ({}) than loop-carried dependencies ({})",
                    n.outputs.len(),
                    num_loop_carried_outputs
                )
            });

        // Validate that body outputs match expected structure:
        // [cond_out, v_out_1..N, scan_out_1..K]
        let expected_body_outputs = 1 + num_loop_carried_outputs + num_scan_outputs;
        if body.outputs.len() != expected_body_outputs {
            panic!(
                "Loop body output mismatch: expected {} outputs (1 cond + {} loop-carried + {} scan), got {}",
                expected_body_outputs,
                num_loop_carried_outputs,
                num_scan_outputs,
                body.outputs.len()
            );
        }

        // Convert node outputs to Type
        // Note: onnx-ir handles type inference for scan outputs (adding concat dimension)
        let outputs: Vec<Type> = n.outputs.iter().map(Type::from).collect();

        Self::new(
            max_trip_count,
            condition,
            v_initial,
            outputs,
            body,
            num_loop_carried_outputs,
        )
    }
}
